{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "g831xANXh2HY",
    "colab_type": "code",
    "colab": {}
   },
   "outputs": [],
   "source": [
    "# http://pytorch.org/\n",
    "import sys\n",
    "from os.path import exists\n",
    "\n",
    "# `wheel.pep425tags` was removed in wheel 0.35, so importing it raises ImportError on current\n",
    "# environments. Build the equivalent CPython wheel tag from `sys` instead (assumes CPython,\n",
    "# as on Colab): e.g. 'cp36-cp36m' on Python 3.6, 'cp38-cp38' on 3.8+ (abiflags is '' there).\n",
    "platform = 'cp{0}{1}-cp{0}{1}{2}'.format(sys.version_info.major, sys.version_info.minor, getattr(sys, 'abiflags', ''))\n",
    "cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\\.\\([0-9]*\\)\\.\\([0-9]*\\)$/cu\\1\\2/'\n",
    "# `cuda_output` is empty when ldconfig finds no cudart library; fall back to 'cpu' instead of\n",
    "# raising IndexError on `cuda_output[0]`.\n",
    "# NOTE(review): `platform` / `accelerator` are presumably meant to build a PyTorch wheel URL for a\n",
    "# pip install, but no install command appears in this cell - confirm whether one was dropped.\n",
    "accelerator = cuda_output[0] if exists('/dev/nvidia0') and cuda_output else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "nNemnO18h6PV",
    "colab_type": "code",
    "outputId": "a770e281-5e28-466d-9e04-db845e0ba90e",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1971.0
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0001 cost = 2.092809\n",
      "Epoch: 0002 cost = 0.255108\n",
      "Epoch: 0003 cost = 0.273580\n",
      "Epoch: 0004 cost = 0.072622\n",
      "Epoch: 0005 cost = 0.023565\n",
      "Epoch: 0006 cost = 0.034211\n",
      "Epoch: 0007 cost = 0.003267\n",
      "Epoch: 0008 cost = 0.003564\n",
      "Epoch: 0009 cost = 0.024341\n",
      "Epoch: 0010 cost = 0.000667\n",
      "Epoch: 0011 cost = 0.000384\n",
      "Epoch: 0012 cost = 0.000739\n",
      "Epoch: 0013 cost = 0.000474\n",
      "Epoch: 0014 cost = 0.000306\n",
      "Epoch: 0015 cost = 0.003650\n",
      "Epoch: 0016 cost = 0.000577\n",
      "Epoch: 0017 cost = 0.001067\n",
      "Epoch: 0018 cost = 0.000163\n",
      "Epoch: 0019 cost = 0.000262\n",
      "Epoch: 0020 cost = 0.000920\n",
      "ich mochte ein bier P -> ['i', 'want', 'a', 'beer', 'E']\n",
      "first head of last state enc_self_attns\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeYAAAH2CAYAAAClRS9UAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAG6RJREFUeJzt3XmMVfX98PHPDAxgqkKiIjgKWsUl\nFRFQgRYFCWoES1CxSl1arHvqXhEx1h9uaGtjVYrx0RqsWhdETas+KipuxbobwarICLKJC3XJIOvM\nff4wzu8hqAwyw/nc8fX6azh3Zu7nfkPmfe85555bUSqVSgEApFBZ9AAAwP8SZgBIRJgBIBFhBoBE\nhBkAEhFmAEhEmAEgEWEGgESEmbLxxRdfxD333BPXXnttw7a5c+cWNxBAMxBmysLzzz8fAwcOjNtv\nvz1uvvnmiIhYuHBhHHroofHUU08VOxxAExJmysIf//jHuOCCC+If//hHVFRUREREdXV1XH311Wu8\nggYod8JMWXjvvffisMMOi4hoCHNExP777293NtCiCDNloWPHjrFgwYK1tr/22mux2WabFTARQPNo\nXfQA0BjDhg2Lk046KY477rior6+PRx55JN5+++24884747jjjit6PIAmU+FjHykHpVIpbr311rj3\n3ntj3rx50a5du+jSpUuMHDkyDj/88KLHA2gywkxZeP/996Nr165rbV+5cmXMmDEjevfuXcBUAE3P\nMWbKwrBhw75x+5dffhknnHDCRp4GoPk4xkxq99xzT9x9992xatWqGDFixFq3f/LJJ9GhQ4cCJgNo\nHsJMagcffHC0b98+zjnnnBg4cOBat7dt2zYGDx688QcDaCaOMVMWHnrooRg6dGjRYwA0O2GmbLz+\n+usxe/bsWLFixVq3HX300QVMBND0hJmycPnll8dtt90WW2yxRbRt23aN2yoqKuKJJ54oaDKApiXM\nlIWePXvGhAkT4mc/+1nRowA0K2+Xoixsttlmsffeexc9BkCzE2bKwhlnnBG33HJL2MEDtHR2ZZPW\n4YcfvsYnSc2fPz9atWoVnTt3XmN7RMS99967sccDaBbex0xa+++/f9EjAGx0XjFTVurq6qJVq1YR\nEbFs2bLYZJNNCp4IoGk5xkxZmDt3bvz85z+PqVOnNmy7884745BDDon333+/wMkAmpYwUxbGjRsX\n++yzT/Tr169h2xFHHBH77rtvjBs3rsDJAJqWXdmUhd69e8cLL7wQrVuveVrEqlWrom/fvvHKK68U\nNBl8uwULFsS2225b9BiUGa+YKQvt27eP2bNnr7V9xowZsemmmxYwEazbsGHDoq6urugxKDPOyqYs\nHHfccTFq1KgYMmRIbLvttlFfXx9z5syJRx55JM4777yix4NvdPTRR8d1110XJ554oieQNJpd2ZSN\nxx9/PO67776YP39+VFRUxHbbbReHH354DBo0qOjR4BsdfPDB8cknn8TSpUtj0003bXhHwdeef/75\ngiYjM2EGaCb333//d95+6KGHbqRJKCfCTNm477774uGHH46FCxdGRUVFdOnSJQ4//PA44IADih4N\n1mnVqlVRVVVV9BiUAceYm0FdXV2sXLlyre0uhvH9TZw4MW699dYYMmRI9O/fPyIi3nvvvRgzZkws\nXbo0hg8fXvCE5evDDz+MSZMmRU1NTSxfvnyt2//2t78VMFXLsHLlyvjLX/4SU6ZMic8//zxmzJgR\ntbW1cdlll8VFF10UP/rRj4oekYSEuQk9//zzMW7cuJg3b943ftjCW2+9VcBULcPkyZPjxhtvjD33\n3HON7cOGDYtx48YJ8wY4++yz47PPPos+ffpEu3btih6nRbn88svjzTffjN///vfxu9/9LiIi6uvr\n49NPP40rrrgiLr/88oInJCNhbkIXXH
BB9OvXL8aOHesPXBP77LPPonv37mtt79mzZyxcuLCAiVqO\nt956K6ZNmxYdOnQoepQW57HHHov7778/OnXq1PDBK5tvvnmMHz8+hg0bVvB0ZCXMTejzzz+PcePG\nRZs2bYoepcXZfvvt44knnogDDzxwje3Tpk1zAYcNtP3223uvbTOpq6uLrbbaaq3tbdq0iaVLlxYw\nEeVAmJvQoEGD4u2334499tij6FFanNNPPz1OP/306NOnT+y4444R8dUx5hdeeCHGjx9f8HTl7bzz\nzosLL7wwjjzyyKiuro7KyjWvO7TTTjsVNFn5+8lPfhI33XRTnHLKKQ3bli5dGldeeaW/E3wrZ2Vv\noDvuuKPh62XLlsWUKVNi4MCB3/gq7uijj96Yo7U4s2bNiilTpsT8+fNj5cqV0aVLlxg+fLg/cBto\n1113XWtbRUVFlEqlqKiocG7EBpg1a1accMIJsXr16vj000/jxz/+cSxcuDC22mqrmDhxYnTr1q3o\nEUlImDdQYy9uUVFREU888UQzTwPrb13H6KurqzfSJC3T8uXLY9q0aTF//vxo165ddO3aNfr377/W\nxUbga8JMWViwYEFMmjQp3n///VixYsVat3tLD9BSOMbchEqlUkyaNCl69eoVPXr0iIiIRx99NBYs\nWBCjRo1a69gdjXfGGWfEqlWrYp999nFyXRMYOHBgPPXUUxER0bdv34Yzhr+Jy0auH2vLhhLmJvSH\nP/whHn/88dhrr70atm255ZZx/fXXx5IlS2L06NEFTlfe5syZE88995wLMjSRs88+u+Hr888/PyIi\nvvzyy2jTps1aH63J+vm2ta2qqoqPP/44OnbsaI03UE1NTTz11FPRqlWrGDx4cIt7Z4Zd2U2of//+\nMWXKlNh6663X2P7hhx/GiBEj4tlnny1osvJ3yimnxG9/+9vYfffdix6lxfnss8/i0ksvjUceeSQq\nKipi5syZ8d///jfOPPPM+NOf/hQdO3YsesSytWjRojj//PPjlVdeiVKpFKVSKVq3bh0DBgyIiy++\n2Np+D9OnT4+TTz45tt9++6ivr49FixbFLbfcEj179ix6tCYjzE1or732iqeffnqtV3WfffZZDBo0\nKF599dWCJit/ixcvjlGjRsVuu+0WW2+99Vq7B+2N+P7OPffcqK2tjTPOOCNGjhwZb7zxRixfvjwu\nueSSqK2tjeuuu67oEcvWscceG1VVVXH88cdHly5dolQqxfvvvx+33nprlEqluOWWW4oeseyMHDky\nhg4dGsccc0xERNx2223x2GOPxW233VbwZE3H/pQm1L9//7jgggvi5JNPjurq6obPDJ44cWIMHDiw\n6PHK2oUXXhiLFy+OzTffPD7++OM1bvuuY3is2zPPPBNTp06NDh06NKxlu3btYuzYsTF48OCCpytv\nM2fOjGeffXaNz2Lu2rVr9OrVK/bbb78CJytfs2fPjl/84hcN/x4xYkRMmDChwImanjA3oYsvvjgu\nvPDCOOKIIxp2W1VWVsbgwYPj0ksvLXq8svbyyy/Hww8/7K07zaB169bfeAnZlStXfuMZ8DRely5d\nora2do0wR3x1zQP/l7+flStXrnEC6CabbPKNH75SzoR5A9XU1DRciWrJkiVxzjnnNJyB/fVRgg4d\nOsQHH3zgCkoboFu3btG2bduix2iRevbsGVdddVXDhyxERMybNy8uu+yy6NevX4GTlafZs2c3fD1q\n1Kg455xz4pe//GXsuOOOUVFREXPmzIk777wzTjvttAKnJDPHmDfQHnvsEW+88UZEfHUFpW/areoK\nShvuwQcfjLvuuiuGDh0anTp1WuutZwMGDChosvK3ePHiOPXUU2PWrFlRV1cX7dq1ixUrVsRee+0V\nV1999VonM/Ldvv47sK4/rf4mfD+77757jB07do31HT9+/FrbyvlKi8K8gRYtWhTbbLNNRLiCUnP6\nps
tGfs0fuKYxY8aMmD9/frRt2za6du1qD8/3tD6fduZvwvprzNUWy/1Ki8IMAIm4FBUAJCLMAJCI\nMANAIsIMAIkIMwAkUnYXGKlf3K3oERqtYouHorRkaNFjtEjWtvmU29rWleqLHqHRWm/5f2P1JwcX\nPUajDanuVfQIjfZ/3vhTnLTHuUWP0WhT6yd/621eMTejiqqdix6hxbK2zcfaNh9r23x22L1L0SM0\nGWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESE\nGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYA\nSESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgkVRhXrhwYXTv3j1mz55d9CgA\nUIjWRQ/w/6uuro4ZM2YUPQYAFCbVK2YA+KFLFeYFCxbELrvsErNmzSp6FAAoRKowA8APXapjzI1R\nscVDUVG1c9FjNFplp3eLHqHFsrbNp5zWttxeXVR1ril6hEabWl/0BOtnav3kokdoEmUX5tKSoVEq\neohGquz0btQv7lb0GC2StW0+5ba2daXyqUdV55pY9cGORY/RaEOqexU9QqNNrZ8cB1QeUfQYjfZd\nTyLK7ckmALRowgwAiQgzACQizACQSKqTv7bddtt45513ih4DAArjFTMAJCLMAJCIMANAIsIMAIkI\nMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswA\nkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0Ai\nwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgz\nACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQ\niDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLC\nDACJCDMAJCLMAJCIMANAIsIMAImkCfN9990XS5YsKXoMAChUijDX1dXF+PHjhRmAH7x1hnngwIEx\nderUhn+fcMIJMWTIkIZ/v/3229G9e/eoqamJk08+Ofr06RN77713nHrqqfHRRx81fN8uu+wSjz76\naIwcOTL23HPPGDZsWLzzzjsREdG7d+/44osv4rDDDos///nPTfn4AKCsrDPMffr0iVdffTUivnpl\n++abb8by5cvj008/jYiIl19+OXr27BmXXHJJbLbZZvHss8/Gk08+GbW1tXHVVVet8btuvvnmuOKK\nK2L69OnRvn37uP766yMi4sEHH4yIr3Znn3XWWU36AAGgnLRe1zf07ds37r777oiIePPNN6Nr167R\nqVOneOWVV2Lw4MHx8ssvR79+/WLUqFEREdGmTZto06ZNDBo0KO666641ftchhxwSO+ywQ0RE7Lff\nfnHfffet98AVWzwUFVU7r/fPFaWy07tFj9BiWdvmU05rm+J43Hqo6lxT9AiNNrW+6AnWz9T6yUWP\n0CQaFeaLLrooVqxYES+99FLstdde0bFjxzXCPGrUqJg5c2Zcc8018fbbb8fKlSujvr4+tt566zV+\n17bbbtvw9SabbBIrVqxY74FLS4ZGab1/qhiVnd6N+sXdih6jRbK2zafc1rauVD71qOpcE6s+2LHo\nMRptSHWvokdotKn1k+OAyiOKHqPRvutJxDqfbHbu3Dm22WabmDFjRrz00kvRu3fv6NmzZ7zyyisx\nb968WL58eXTp0iVOOumk2H333WPatGkxY8aMGD169Np3Vlluz20B
YONqVCn79u0bL7/8crz22mvR\nq1ev2G233WLOnDnx3HPPxT777BNz586NpUuXxm9+85vYfPPNI+Kr3d4AwPppdJgfeOCB6NixY7Rv\n3z5at24du+66a9xxxx3Rr1+/2GabbaKysjJee+21WLZsWdx9990xZ86c+Pzzz2P58uXr/P3t2rWL\niIi5c+dGbW3thj0iAChjjQpznz59Yu7cudG7d++Gbb169YrZs2fHT3/609h6661j9OjRcfHFF8eA\nAQOipqYmrrvuuujQoUMceOCB6/z9W265ZRx00EFxzjnnxNVXX/39Hw0AlLmKUqlULudSRUSU1Ukp\n5XYSTTmxts2n3NbWyV/Nx8lfzWeDTv4CADYeYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhE\nmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEG\ngESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQAS\nEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESY\nASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaA\nRIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIR\nZgBIRJgBIBFhBoBECg3zm2++Gccee2zsvffe0bdv3xg9enTU1tYWORIAFKrQMJ911lnRo0eP+Pe/\n/x0PPvhgzJw5M2666aYiRwKAQlWUSqVSUXe+dOnSqKqqijZt2kRExGWXXRZz5syJv/71r9/6M6VV\ns6KiaueNNSIAbFSti7zz559/PiZOnBhz5syJ1atXR11dXfTu3fs7f6a0ZGgU9kxiPVV2ejfqF3cr\neowWydo2n3Jb27pSfdEjNFpV55pY9cGORY/RaEOqexU9QqNNrZ8cB1QeUfQYjTa1fvK33lbYruya\nmpo488wz45BDDonp06fHjBkz4phjjilqHABIobBXzG+99Va0atUqRo0aFRUVFRHx1clglZVOFAfg\nh6uwCm633XaxcuXKmDlzZtTW1saECRNi2bJl8fHHH0ddXV1RYwFAoQoLc48ePeLXv/51jBo1Kg46\n6KCoqqqKK664Ir744gu7tAH4wSr05K8xY8bEmDFj1tg2ffr0gqYBgOI5oAsAiQgzACQizACQiDAD\nQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJ\nCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLM\nAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANA\nIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkI\nMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswA\nkIgwA0AiwgwAiQgzACQizACQSKPDvGDBgthll11i1qxZzTkPAPygecUMAIkIMwAkst5h/s9//hPD\nhg2Lnj17xjHHHBMLFy6MiIgXX3wxjjrqqOjVq1f0798/rrnmmqivr2/4ub///e8xZMiQ6NGjRxx0\n0EHx8MMPN9x27LHHxlVXXRXDhw+PX/3qV03wsACgPK13mO+666644YYb4umnn46qqqo4//zzY/Hi\nxXHyySfHiBEj4sUXX4xJkybF
P//5z7jnnnsiIuLxxx+Pa6+9Nq688sp49dVXY8yYMTF69Oioqalp\n+L0PPfRQXHzxxTFp0qQme3AAUG5ar+8PjBw5MqqrqyMi4vjjj48TTzwxpkyZEjvssEOMGDEiIiJ2\n2mmnOPbYY+P++++Po446Ku6555447LDDYo899oiIiP333z/69+8fDzzwQJx77rkREdG9e/fo2bPn\nOu+/YouHoqJq5/UduzCVnd4teoQWy9o2n3Ja23I7HlfVuWbd35TE1Pp1f08mU+snFz1Ck1jvMO+0\n004NX3fp0iVKpVK88MIL8dZbb0X37t0bbiuVSrHllltGRMS8efPiX//6V9x+++1r3L7ZZps1/Hub\nbbZp1P2XlgyN0voOXZDKTu9G/eJuRY/RIlnb5lNua1tXKp96VHWuiVUf7Fj0GI02pLpX0SM02tT6\nyXFA5RFFj9Fo3/UkYr3DXFn5v89PS6WvElldXR1t2rSJm2+++Rt/pl27dnHmmWfGSSed9O2DtF7v\nUQCgxVnvvUBz5sxp+HrevHnRqlWr2HXXXePdd99d42SvJUuWxPLlyyPiq1fW77zzzhq/Z9GiRWt8\nPwDwPcJ85513xkcffRS1tbVx6623xoABA2L48OFRW1sb119/fSxbtiwWLVoUJ554Ytx4440R8dVx\n6UcffTQef/zxWL16dbz66qsxfPjweOGFF5r8AQFAOVvvMI8cOTKOP/742HfffWP16tXxP//zP9G+\nffu44YYb4plnnok+ffrEkUceGXvvvXecdtppERHRr1+/GDt2bIwfPz569eoVY8eOjfPOOy/69evX\n5A8IAMpZRenrA8VlopxOSim3k2jKibVtPuW2tk7+aj5O/mo+33XyV7m90wAAWjRhBoBEhBkAEhFm\nAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEg\nEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESE\nGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYA\nSESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASAR\nYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZ\nABIRZgBIRJgBIBFhBoBEhBkAEhFmAEikdRF3OmjQoPjwww+jsnLt5wVjx46NkSNHFjAVABSvkDBH\nRFxwwQVxzDHHFHX3AJCSXdkAkIgwA0AiFaVSqbSx7/S7jjG//vrr0apVq2/92dKqWVFRtXNzjgcA\nhSm7Y8ylJUNjoz+T+J4qO70b9Yu7FT1Gi2Rtm0+5rW1dqb7oERqtqnNNrPpgx6LHaLQh1b2KHqHR\nptZPjgMqjyh6jEabWj/5W2+zKxsAEhFmAEhEmAEgkcKOMY8fPz6uuuqqtbYPGDAgJkyYUMBEAFC8\nQsL85JNPFnG3AJCeXdkAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIM\nAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAk\nIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkEhF\nqVQqFT0EAPAVr5gBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgET+Hy6qs27DXv9XAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 576x576 with 1 Axes>"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "first head of last state dec_self_attns\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeYAAAH2CAYAAAClRS9UAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAHD9JREFUeJzt3XlwVYXd8PFfwmqrwEwVFRBs1apT\nFQFlaVEQUR7RUlSsRdEWq7hM3SsqjvXBqqi101atfX21FqrWFezUZVDAvVh3R7AqBUG24kJFnyAQ\nSO77h2OelwElkcTzu/Hz+Ss5Jzf53TNwvzlLzq0olUqlAABSqCx6AADgfwkzACQizACQiDADQCLC\nDACJCDMAJCLMAJCIMANAIsJM2fjoo4/i7rvvjt/97nd1yxYsWFDcQABNQJgpC88880wMHDgwbrvt\ntrj55psjImLJkiVx+OGHx+OPP17scACNSJgpC7/61a/iwgsvjL/97W9RUVERERGdO3eOa665Zr09\naIByJ8yUhbfeeiuOOOKIiIi6MEdEHHDAAQ5nA82KMFMWOnbsGIsXL95g+csvvxxbbbVVARMBNI2W\nRQ8A9TFs2LAYM2ZMHH/88VFbWxtTp06NN954I+644444/vjjix4PoNFUeNtHykGpVIpJkybFvffe\nGwsXLoy2bdtG165dY+TIkXHkkUcWPR5AoxFmysLbb78d3bp122B5dXV1zJo1K3r16lXAVACNzzlm\nysKwYcM2uvzjjz+OE0888UueBqDpOMdManfffXfcddddsXbt2hgxYsQG699///3o0KFDAZMBNA1h\nJrVDDjkk2rdvH+ecc04MHDhwg/Vt2rSJwYMHf/mDATQR55gpCw8++GAceuihRY8B0OSEmbLxyiuv\nxNy5c2PNmjUbrDv22GMLmAig8QkzZeHyyy+PW2+9Nb7xjW9EmzZt1ltXUVERM2bMKGgygMYlzJSF\nHj16xPXXXx/f+973ih4FoEn5cynKwlZbbRX77rtv0WMANDlhpiycccYZccstt4QDPEBz51A2aR15\n5JHrvZPUokWLokWLFrH99tuvtzwi4t577/2yxwNoEv6OmbQOOOCAokcA+NLZY6as1NTURIsWLSIi\nYtWqVbHFFlsUPBFA43KOmbKwYMGC+P73vx/Tpk2rW3bHHXfEYYcdFm+//XaBkwE0LmGmLIwfPz56\n9+4d/fr1q1t21FFHxX777Rfjx48vcDKAxuVQNmWhV69e8eyzz0bLlutfFrF27dro27dvvPjiiwVN\nBp9t8eLF0aVLl6LHoMzYY6YstG/fPubOnbvB8lmzZsWWW25ZwESwacOGDYuampqix6DMuCqbsnD8\n8cfH6NGjY+jQodGlS5eora2N+fPnx9SpU+O8884rejzYqGOPPTauvfbaOOmkk/wCSb05lE3ZmD59\nekyZMiUWLVoUFRUVscMOO8SRRx4ZgwYNKno02KhDDjkk3n///Vi5cmVsueWWdX9R8KlnnnmmoMnI\nTJgBmsh99933uesPP/zwL2kSyokwUzamTJkSDz30UCxZsiQqKiqia9euceSRR8ZBBx1U9GiwSWvX\nro1WrVoVPQZlwDnmJlBTUxPV1dUbLHczjC/uhhtuiEmTJsXQoUOjf//+ERHx1ltvxQUXXBArV66M\n4cOHFzxh+XrnnXdi4sSJMW/evFi9evUG6//85z8XMFXzUF1dHb///e9j8uTJ8eGHH8asWbOiqqoq\nLrvssrj44ovj61//etEjkpAwN6Jnnnkmxo8fHwsXLtzomy28/vrrBUzVPNxzzz1x4403xt57773e\n8mHDhsX48eOFeTOcffbZsWLFiujTp0+0bdu26HGalcsvvzxee+21+MUvfhE///nPIyKitrY2Pvjg\ng7jiiivi8ssvL3hCMhLmRnThhRdGv3
79Yty4cV7gGtmKFStizz333GB5jx49YsmSJQVM1Hy8/vrr\n8dhjj0WHDh2KHqXZeeSRR+K+++6L7bbbru6NV9q1axcTJkyIYcOGFTwdWQlzI/rwww9j/Pjx0bp1\n66JHaXZ23HHHmDFjRhx88MHrLX/sscfcwGEz7bjjjv7WtonU1NTENttss8Hy1q1bx8qVKwuYiHIg\nzI1o0KBB8cYbb8Ree+1V9CjNzumnnx6nn3569OnTJ3baaaeI+OQc87PPPhsTJkwoeLrydt5558VF\nF10URx99dHTu3DkqK9e/79DOO+9c0GTl7zvf+U7cdNNNccopp9QtW7lyZVx55ZVeJ/hMrsreTLff\nfnvdx6tWrYrJkyfHwIEDN7oXd+yxx36ZozU7c+bMicmTJ8eiRYuiuro6unbtGsOHD/cCt5l22223\nDZZVVFREqVSKiooK10Zshjlz5sSJJ54Y69atiw8++CC+9a1vxZIlS2KbbbaJG264IXbZZZeiRyQh\nYd5M9b25RUVFRcyYMaOJp4GG29Q5+s6dO39JkzRPq1evjsceeywWLVoUbdu2jW7dukX//v03uNkI\nfEqYKQuLFy+OiRMnxttvvx1r1qzZYL0/6QGaC+eYG1GpVIqJEydGz549o3v37hER8fDDD8fixYtj\n9OjRG5y7o/7OOOOMWLt2bfTu3dvFdY1g4MCB8fjjj0dERN++feuuGN4Yt41sGNuWzSXMjejqq6+O\n6dOnxz777FO3bOutt47rrrsuli9fHmPHji1wuvI2f/78ePrpp92QoZGcffbZdR+ff/75ERHx8ccf\nR+vWrTd4a00a5rO2batWreK9996Ljh072sabad68efH4449HixYtYvDgwc3uLzMcym5E/fv3j8mT\nJ8e222673vJ33nknRowYEU899VRBk5W/U045JX72s5/FHnvsUfQozc6KFSvil7/8ZUydOjUqKipi\n9uzZ8Z///CfOPPPM+PWvfx0dO3YsesSytXTp0jj//PPjxRdfjFKpFKVSKVq2bBkDBgyISy65xLb9\nAmbOnBknn3xy7LjjjlFbWxtLly6NW265JXr06FH0aI1GmBvRPvvsE0888cQGe3UrVqyIQYMGxUsv\nvVTQZOVv2bJlMXr06Nh9991j22233eDwoKMRX9y5554bVVVVccYZZ8TIkSPj1VdfjdWrV8ell14a\nVVVVce211xY9Ytk67rjjolWrVnHCCSdE165do1Qqxdtvvx2TJk2KUqkUt9xyS9Ejlp2RI0fGoYce\nGqNGjYqIiFtvvTUeeeSRuPXWWwuerPE4ntKI+vfvHxdeeGGcfPLJ0blz57r3DL7hhhti4MCBRY9X\n1i666KJYtmxZtGvXLt5777311n3eOTw27cknn4xp06ZFhw4d6rZl27ZtY9y4cTF48OCCpytvs2fP\njqeeemq992Lu1q1b9OzZM/bff/8CJytfc+fOjR/+8Id1n48YMSKuv/76AidqfMLciC655JK46KKL\n4qijjqo7bFVZWRmDBw+OX/7yl0WPV9ZeeOGFeOihh/zpThNo2bLlRm8hW11dvdEr4Km/rl27RlVV\n1Xphjvjkngf+LX8x1dXV610AusUWW2z0zVfKmTBvpnnz5tXdiWr58uVxzjnn1F2B/elZgg4dOsS/\n//1vd1DaDLvssku0adOm6DGapR49esRVV11V9yYLERELFy6Myy67LPr161fgZOVp7ty5dR+PHj06\nzjnnnDjmmGNip512ioqKipg/f37ccccdcdpppxU4JZk5x7yZ9tprr3j11Vcj4pM7KG3ssKo7KG2+\nBx54IO6888449NBDY7vtttvgT88GDBhQ0GTlb9myZXHqqafGnDlzoqamJtq2bRtr1qyJffbZJ665\n5poNLmbk8336OrCpl1avCV/MHnvsEePGjVtv+06YMGGDZeV8p0Vh3kxLly6NTp06RYQ7KDWljd02\n8l
Ne4BrHrFmzYtGiRdGmTZvo1q2bIzxfUEPe7cxrQsPV526L5X6nRWEGgETcigoAEhFmAEhEmAEg\nEWEGgESEGQASKbsbjNQu26XoEeqt4hsPRmn5oUWPUW9DOu1d9Aj19n9f/XWM2evcosdolmzbpmPb\nNp1y27bTau/5zHX2mJtQRatvFz1Cs/XNPboWPUKzZds2Hdu26TSnbSvMAJCIMANAIsIMAIkIMwAk\nIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgw\nA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwA\niQgzACQizACQiDADQCLCDACJCDMAJJIqzEuWLIk999wz5s6dW/QoAFCIlkUP8P/r3LlzzJo1q+gx\nAKAwqfaYAeCrLlWYFy9eHLvuumvMmTOn6FEAoBCpwgwAX3UVpVKpVPQQn1q8eHEceOCBcf/998e3\nv/3tjX5Nae2cqGi18XUAUO5SXfxVH6Xlh0aa3yQ2oXK7f0Xtsl2KHqPehnTau+gR6m1a7T1xUOVR\nRY/RLNm2Tce2bTrltm2n1d7zmescygaARIQZABIRZgBIRJgBIJFUF3916dIl3nzzzaLHAIDC2GMG\ngESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQAS\nEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESY\nASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgERaFj1AQ1XVri56hHprF+U1\nb0WbNkWP0CDlNG9pzZqiRwDKhD1mAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBI\nRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFh\nBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkA\nEhFmAEhEmAEgkTRhnjJlSixfvrzoMQCgUCnCXFNTExMmTBBmAL7yNhnmgQMHxrRp0+o+P/HEE2Po\n0KF1n7/xxhux5557xrx58+Lkk0+OPn36xL777hunnnpqvPvuu3Vft+uuu8bDDz8cI0eOjL333juG\nDRsWb775ZkRE9OrVKz766KM44ogj4re//W1jPj8AKCubDHOfPn3ipZdeiohP9mxfe+21WL16dXzw\nwQcREfHCCy9Ejx494tJLL42tttoqnnrqqXj00UejqqoqrrrqqvW+18033xxXXHFFzJw5M9q3bx/X\nXXddREQ88MADEfHJ4eyzzjqrUZ8gAJSTlpv6gr59+8Zdd90VERGvvfZadOvWLbbbbrt48cUXY/Dg\nwfHCCy9Ev379YvTo0RER0bp162jdunUMGjQo7rzzzvW+12GHHRbf/OY3IyJi//33jylTpjR44K9v\nMz1atNq1wY8rSrtOi4oeod4eWVX0BA3zyKrbih6h2ZpWe0/RIzRbtm3TaS7btl5hvvjii2PNmjXx\n/PPPxz777BMdO3ZcL8yjR4+O2bNnx29+85t44403orq6Ompra2Pbbbdd73t16dKl7uMtttgi1qxZ\n0+CBV743uMGPKUq7Tovio6U7FD1GvY3YaUDRI9TbI6tui4O3GFX0GPVW+gL/1osyrfaeOKjyqKLH\naJZs26ZTbtv2836J2OSh7O233z46deoUs2bNiueffz569eoVPXr0iBdffDEWLlwYq1evjq5du8aY\nMWNijz32iMceeyxmzZoVY8eO3fCHVaa41gwA0qpXKfv27RsvvPBC
vPzyy9GzZ8/YfffdY/78+fH0\n009H7969Y8GCBbFy5cr46U9/Gu3atYuITw57AwANU+8w//Wvf42OHTtG+/bto2XLlrHbbrvF7bff\nHv369YtOnTpFZWVlvPzyy7Fq1aq46667Yv78+fHhhx/G6tWrN/n927ZtGxERCxYsiKqqqs17RgBQ\nxuoV5j59+sSCBQuiV69edct69uwZc+fOje9+97ux7bbbxtixY+OSSy6JAQMGxLx58+Laa6+NDh06\nxMEHH7zJ77/11lvHkCFD4pxzzolrrrnmiz8bAChzFaVSqVT0EA1RThdTufir6bj4q+mU20U05cS2\nbTrltm036+IvAODLI8wAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIM\nAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAk\nIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAk0rLoARpq\ny8q2RY/QIOU0b2nNmqJHaJBymxegPuwxA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCI\nMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIM\nAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAk\nIswAkIgwA0AiwgwAiQgzACQizACQSKFhfu211+K4446LfffdN/r27Rtjx46NqqqqIkcCgEIVGuaz\nzjorunfvHv/4xz/igQceiNmzZ8dNN91U5EgAUKiKUqlUKuqHr1y5Mlq1ahWtW7eOiIjLLrss5s+f\nH3/84x8/8zGltXOiotW3v6wRAeBL1bLIH/7MM8/EDTfcEPPnz49169ZFTU1N9OrV63MfU1p+aBT2\nm0QDVW73r6hdtkvRY9TbkE57Fz1CvU2rvScOqjyq6DGaJdu26di2Tafctu202ns+c11hh7LnzZsX\nZ555Zhx22GExc+bMmDVrVowaNaqocQAghcL2mF9//fVo0aJFjB49OioqKiLik4vBKitdKA7AV1dh\nFdxhhx2iuro6Zs+eHVVVVXH99dfHqlWr4r333ouampqixgKAQhUW5u7du8dPfvKTGD16dAwZMiRa\ntWoVV1xxRXz00UcOaQPwlVXoxV8XXHBBXHDBBestmzlzZkHTAEDxnNAFgESEGQASEWYASESYASAR\nYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZ\nABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBI\nRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgERaFj1AQ/1jdU3RI9Tbd6O85m2x9TeKHqFBymne2g//\np+gRGqSiVeuiR6i3Uk35/B+LiIjKFkVPUH+1ZbZtmwl7zACQiDADQCLCDACJCDMAJCLMAJCIMANA\nIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkI\nMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswA\nkIgwA0AiwgwAiQgzACQizACQSL3DvHjx4th1111jzpw5TTkPAHyl2WMGgESEGQASaXCY//nPf8aw\nYcOiR48eMWrUqFiyZElERDz33HPxox/9KHr27Bn9+/eP3/zmN1FbW1v3uL/85S8xdOjQ6N69ewwZ\nMiQeeuihunXHHXdcXHXVVTF8
+PD48Y9/3AhPCwDKU4PDfOedd8Yf/vCHeOKJJ6JVq1Zx/vnnx7Jl\ny+Lkk0+OESNGxHPPPRcTJ06M+++/P+6+++6IiJg+fXr87ne/iyuvvDJeeumluOCCC2Ls2LExb968\nuu/74IMPxiWXXBITJ05stCcHAOWmolQqlerzhYsXL44DDzwwrr766vjBD34QERFPPfVUnHTSSXH6\n6afHjBkzYsqUKXVf/6c//SmmTp0ad911V4wZMyZ22mmnOP/88+vWn3LKKbHLLrvEueeeG8cdd1y0\na9cufv/7329yjo+r34yvtd61oc8TAMpCy4Y+YOedd677uGvXrlEqleLZZ5+N119/Pfbcc8+6daVS\nKbbeeuuIiFi4cGH8/e9/j9tuu2299VtttVXd5506darXz39l6SENHbkw393xrZi54FtFj1Fvl/Y+\nuOgR6m3qu/8n/qvjKUWPUW+1H/5P0SPU2yNrbo+D2xxb9Bj1VqqpKXqEepu27s44qOWPih6j/mrL\naNvW3hMHVR5V9Bj1Nq32ns9c1+AwV1b+79HvT3e2O3fuHK1bt46bb755o49p27ZtnHnmmTFmzJjP\nHqRlg0cBgGanweeY58+fX/fxwoULo0WLFrHbbrvFv/71r/Uu9lq+fHmsXr06Ij7Zs37zzTfX+z5L\nly5d7+sBgC8Q5jvuuCPefffdqKqqikmTJsWAAQNi+PDhUVVVFdddd12sWrUqli5dGieddFLceOON\nERExcuTIePjhh2P69Omxbt26eOmll2L48OHx7LPPNvoTAoBy1uAwjxw5Mk444YTYb7/9Yt26dfHf\n//3f0b59+/jDH/4QTz75ZPTp0yeOPvro2HfffeO0006LiIh+/frFuHHjYsKECdGzZ88YN25cnHfe\nedGvX79Gf0IAUM7qfWK3S5cudYejhw4dusH63r17x+TJkz/z8cccc0wcc8wxG11366231ncMAGjW\n3PkLABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIR\nZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgB\nIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABJpWfQADbWspn3RIzRI\nOc1b+nhV0SM0SDnNW7lF26JHaJBymrd2zZqiR2iQilbl87Jb0aJN0SM0SOXXvlb0CI3CHjMAJCLM\nAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANA\nIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkI\nMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJBIyyJ+6KBBg+Kdd96JysoN\nfy8YN25cjBw5soCpAKB4hYQ5IuLCCy+MUaNGFfXjASAlh7IBIBFhBoBEKkqlUunL/qGfd475lVde\niRYtWnzmYz+snhvtW+/clOMBQGHK7hzzjEVHNcE0TeOInV6OKfN6FD1Gvd3Yfa+iR6i3h6smxZAt\nf1z0GPVW0bKw/2oNNnXFH+O/Ovy06DHqrXbNmqJHqLdHVt0WB29RPtfWVHzOTlI25faa8HDVpM9c\n51A2ACQizACQiDADQCKFnfiaMGFCXHXVVRssHzBgQFx//fUFTAQAxSskzI8++mgRPxYA0nMoGwAS\nEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESY\nASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQ
aA\nRIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASqSiVSqWihwAAPmGPGQASEWYA\nSESYASARYQaARIQZABIRZgBI5P8B1EjV+06cOJIAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 576x576 with 1 Axes>"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "first head of last state dec_enc_attns\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeYAAAH2CAYAAAClRS9UAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGzFJREFUeJzt3XmMVfX5+PFnhkVMVUhURFDQuqci\nmwi0KEhQoliCilWq2GLdU/eKgLF8sSrS2liXYvxqDVatC6KmVaOi4lYtCmgEKyIIshUXKppB9rm/\nP4zz/RFQBpnhPHd8vf4azp2Z+9xPyH3PPefccytKpVIpAIAUKoseAAD4P8IMAIkIMwAkIswAkIgw\nA0AiwgwAiQgzACQizACQiDBTNr744ot46KGH4qabbqrZNn/+/OIGAqgHwkxZeO2116J3795x7733\nxp133hkREYsXL47jjz8+XnjhhWKHA6hDwkxZ+MMf/hAjRoyIv//971FRUREREW3atIkbbrhhg1fQ\nAOVOmCkLH3zwQZxwwgkRETVhjog48sgj7c4GGhRhpiy0bNkyFi1atNH2N998M3bccccCJgKoH42L\nHgBqY8CAAXH22WfH6aefHtXV1fHUU0/FrFmz4v7774/TTz+96PEA6kyFj32kHJRKpbj77rvj4Ycf\njgULFkSzZs2ibdu2MXjw4DjxxBOLHg+gzggzZeHDDz+Mdu3abbR9zZo1MWPGjOjSpUsBUwHUPceY\nKQsDBgzY5PYvv/wyzjzzzG08DUD9cYyZ1B566KF48MEHY+3atTFo0KCNbv/000+jRYsWBUwGUD+E\nmdSOOeaYaN68eVx66aXRu3fvjW7fbrvtom/fvtt+MIB64hgzZeGJJ56I/v37Fz0GQL0TZsrGW2+9\nFXPmzInVq1dvdNupp55awEQAdU+YKQvXXntt3HPPPbHzzjvHdtttt8FtFRUV8dxzzxU0GUDdEmbK\nQqdOneLWW2+Nn/zkJ0WPAlCvvF2KsrDjjjtG165dix4DoN4JM2XhwgsvjLvuuivs4AEaOruySevE\nE0/c4JOkFi5cGI0aNYrdd999g+0REQ8//PC2Hg+gXngfM2kdeeSRRY8AsM15xUxZWb9+fTRq1Cgi\nIlauXBnbb799wRMB1C3HmCkL8+fPj5/+9KcxadKkmm33339/HHfccfHhhx8WOBlA3RJmysLo0aPj\nsMMOix49etRsO+mkk+Lwww+P0aNHFzgZQN2yK5uy0KVLl5gyZUo0brzhaRFr166N7t27x7Rp0wqa\nDL7ZokWLYo899ih6DMqMV8yUhebNm8ecOXM22j5jxozYYYcdCpgINm/AgAGxfv36osegzDgrm7Jw\n+umnx9ChQ+PYY4+NPfbYI6qrq2PevHnx1FNPxeWXX170eLBJp556atx8881x1lln+QOSWrMrm7Lx\n7LPPxiOPPBILFy6MioqK2HPPPePEE0+MPn36FD0abNIxxxwTn376aaxYsSJ22GGHmncUfO21114r\naDIyE2aAevLoo49+6+3HH3/8NpqEciLMlI1HHnkknnzyyVi8eHFUVFRE27Zt48QTT4yjjjqq6NFg\ns9auXRtNmjQpegzKgGPM9WD9+vWxZs2ajba7GMZ3N27cuLj77rvj2GOPjZ49e0ZExAcffBDDhw+P\nFStWxMCBAwuesHx99NFHMX78+Jg7d26sWrVqo9v/+te/FjBVw7BmzZr485//HBMnTozPP/88ZsyY\nEVVVVXHNNdfEVVddFT/4wQ+KHpGEhLkOvfbaazF69OhYsGDBJj9s4d133y1gqoZhwoQJcfvtt0fH\njh032D5gwIAYPXq0MG+FSy65JJYvXx7dunWLZs2aFT1Og3LttdfGO++8E7/97W/jN7/5TUREVFdX\nx2effRbXXXddXHvttQVPSEbCXIdGjB
gRPXr0iJEjR3qCq2PLly+P9u3bb7S9U6dOsXjx4gImajje\nfffdmDx5crRo0aLoURqcZ555Jh599NFo1apVzQev7LTTTjFmzJgYMGBAwdORlTDXoc8//zxGjx4d\nTZs2LXqUBmevvfaK5557Lo4++ugNtk+ePNkFHLbSXnvt5b229WT9+vWx6667brS9adOmsWLFigIm\nohwIcx3q06dPzJo1Kw455JCiR2lwLrjggrjggguiW7dusc8++0TEV8eYp0yZEmPGjCl4uvJ2+eWX\nx5VXXhknn3xytGnTJiorN7zu0L777lvQZOXvRz/6Udxxxx1x7rnn1mxbsWJFXH/99Z4n+EbOyt5K\n9913X83XK1eujIkTJ0bv3r03+Sru1FNP3ZajNTizZ8+OiRMnxsKFC2PNmjXRtm3bGDhwoCe4rXTg\ngQdutK2ioiJKpVJUVFQ4N2IrzJ49O84888xYt25dfPbZZ/HDH/4wFi9eHLvuumuMGzcu9ttvv6JH\nJCFh3kq1vbhFRUVFPPfcc/U8DWy5zR2jb9OmzTaapGFatWpVTJ48ORYuXBjNmjWLdu3aRc+ePTe6\n2Ah8TZgpC4sWLYrx48fHhx9+GKtXr97odm/pARoKx5jrUKlUivHjx0fnzp2jQ4cOERHx9NNPx6JF\ni2Lo0KEbHbuj9i688MJYu3ZtHHbYYU6uqwO9e/eOF154ISIiunfvXnPG8Ka4bOSWsbZsLWGuQ7//\n/e/j2WefjUMPPbRm2y677BK33HJLLFu2LIYNG1bgdOVt3rx58corr7ggQx255JJLar6+4oorIiLi\nyy+/jKZNm2700ZpsmW9a2yZNmsQnn3wSLVu2tMZbae7cufHCCy9Eo0aNom/fvg3unRl2Zdehnj17\nxsSJE2O33XbbYPtHH30UgwYNipdffrmgycrfueeeG7/+9a/j4IMPLnqUBmf58uXxu9/9Lp566qmo\nqKiImTNnxn//+9+46KKL4o9//GO0bNmy6BHL1pIlS+KKK66IadOmRalUilKpFI0bN45evXrFqFGj\nrO138Oqrr8Y555wTe+21V1RXV8eSJUvirrvuik6dOhU9Wp0R5jp06KGHxosvvrjRq7rly5dHnz59\nYvr06QVNVv6WLl0aQ4cOjYMOOih22223jXYP2hvx3V122WVRVVUVF154YQwePDjefvvtWLVqVVx9\n9dVRVVUVN998c9Ejlq0hQ4ZEkyZN4owzzoi2bdtGqVSKDz/8MO6+++4olUpx1113FT1i2Rk8eHD0\n798/TjvttIiIuOeee+KZZ56Je+65p+DJ6o79KXWoZ8+eMWLEiDjnnHOiTZs2NZ8ZPG7cuOjdu3fR\n45W1K6+8MpYuXRo77bRTfPLJJxvc9m3H8Ni8l156KSZNmhQtWrSoWctmzZrFyJEjo2/fvgVPV95m\nzpwZL7/88gafxdyuXbvo3LlzHHHEEQVOVr7mzJkTP/vZz2r+PWjQoLj11lsLnKjuCXMdGjVqVFx5\n5ZVx0kkn1ey2qqysjL59+8bvfve7oscra1OnTo0nn3zSW3fqQePGjTd5Cdk1a9Zs8gx4aq9t27ZR\nVVW1QZgjvrrmgf/L382aNWs2OAF0++233+SHr5QzYd5Kc+fOrbkS1bJly+LSSy+tOQP766MELVq0\niP/85z+uoLQV9ttvv9huu+2KHqNB6tSpU4wdO7bmQxYiIhYsWBDXXHNN9OjRo8DJytOcOXNqvh46\ndGhceuml8fOf/zz22WefqKioiHnz5sX9998f559/foFTkpljzFvpkEMOibfffjsivrqC0qZ2q7qC\n0tZ7/PHH44EHHoj+/ftHq1atNnrrWa9evQqarPwtXbo0zjvvvJg9e3asX78+mjVrFqtXr45DDz00\nbrjhho1OZuTbff08sLmnVs8J383BBx8cI0eO3GB9x4wZs9G2cr7SojBvpSVLlkTr1q0jwhWU6tOm\nLh
v5NU9wdWPGjBmxcOHC2G677aJdu3b28HxHW/JpZ54TtlxtrrZY7ldaFGYASMSlqAAgEWEGgESE\nGQASEWYASESYASCRsrvASPXS/YoeodYqdn4iSsv6Fz1Gg2Rt64+1rT/ltrb9WncseoRa+9+3/xhn\nH3JZ0WPU2qTqCd94m1fM9aiiyf5Fj9BgWdv6Y23rj7WtP3sf3LboEeqMMANAIsIMAIkIMwAkIswA\nkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0Ai\nwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgz\nACQizACQiDADQCLCDACJCDMAJCLMAJBIqjAvXrw42rdvH3PmzCl6FAAoROOiB/j/tWnTJmbMmFH0\nGABQmFSvmAHg+y5VmBctWhQHHHBAzJ49u+hRAKAQqcIMAN93qY4x10bFzk9ERZP9ix6j1ipbvV/0\nCA2Wta0/1rb+lNPaTqoueoItM6l6QtEj1ImyC3NpWf8oFT1ELVW2ej+ql+5X9BgNkrWtP9a2/pTb\n2vZr3bHoEWptUvWEOKrypKLHqLVv+yPCrmwASESYASARYQaARIQZABJJdfLXHnvsEe+9917RYwBA\nYbxiBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBE\nhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFm\nAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEg\nEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESE\nGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYA\nSESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgkTRhfuSRR2LZsmVFjwEAhUoR\n5vXr18eYMWOEGYDvvc2GuXfv3jFp0qSaf5955plx7LHH1vx71qxZ0b59+5g7d26cc8450a1bt+ja\ntWucd9558fHHH9d83wEHHBBPP/10DB48ODp27BgDBgyI9957LyIiunTpEl988UWccMIJ8ac//aku\nHx8AlJXNhrlbt24xffr0iPjqle0777wTq1atis8++ywiIqZOnRqdOnWKq6++Onbcccd4+eWX4/nn\nn4+qqqoYO3bsBr/rzjvvjOuuuy5effXVaN68edxyyy0REfH4449HxFe7sy+++OI6fYAAUE4ab+4b\nunfvHg8++GBERLzzzjvRrl27aNWqVUybNi369u0bU6dOjR49esTQoUMjIqJp06bRtGnT6NOnTzzw\nwAMb/K7jjjsu9t5774iIOOKII+KRRx7Z4oErdn4iKprsv8U/V5TKVu8XPUKDZW3rj7WtP+W0tpOq\ni55gy0yqnlD0CHWiVmG+6qqrYvXq1fHGG2/EoYceGi1bttwgzEOHDo2ZM2fGjTfeGLNmzYo1a9ZE\ndXV17Lbbbhv8rj322KPm6+233z5Wr169xQOXlvWP0hb/VDEqW70f1Uv3K3qMBsna1h9rW3/KbW37\nte5Y9Ai1Nql6QhxVeVLRY9Tat/0Rsdld2bvvvnu0bt06ZsyYEW+88UZ06dIlOnXqFNOmTYsFCxbE\nqlWrom3btnH22WfHwQcfHJMnT44ZM2bEsGHDNr6zyhTnmgFAWrUqZffu3WPq1Knx5ptvRufOneOg\ngw6KefPmxSuvvBKHHXZYzJ8/P1asWBG/+tWvYqeddoqIr3Z7AwBb
ptZhfuyxx6Jly5bRvHnzaNy4\ncRx44IFx3333RY8ePaJ169ZRWVkZb775ZqxcuTIefPDBmDdvXnz++eexatWqzf7+Zs2aRUTE/Pnz\no6qqauseEQCUsVqFuVu3bjF//vzo0qVLzbbOnTvHnDlz4sc//nHstttuMWzYsBg1alT06tUr5s6d\nGzfffHO0aNEijj766M3+/l122SX69esXl156adxwww3f/dEAQJmrKJVK5XIuVUREWZ04UW4nepQT\na1t/rG39Kbe1dfJX/dmqk78AgG1HmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgB\nIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBE\nhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFm\nAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEg\nEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESE\nGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYA\nSESYASCRQsP8zjvvxJAhQ6Jr167RvXv3GDZsWFRVVRU5EgAUqtAwX3zxxdGhQ4f417/+FY8//njM\nnDkz7rjjjiJHAoBCVZRKpVJRd75ixYpo0qRJNG3aNCIirrnmmpg3b1785S9/+cafKa2dHRVN9t9W\nIwLANtW4yDt/7bXXYty4cTFv3rxYt25drF+/Prp06fKtP1Na1j8K+0tiC1W2ej+ql+5X9BgNkrWt\nP9a2/pTb2vZr3bHoEWptUvWEOKrypKLHqLVJ1RO+8bbCdmXPnTs3LrroojjuuOPi1VdfjRkzZsRp\np51W1DgAkEJhr5jffffdaNSoUQwdOjQqKioi4quTwSornSgOwPdXYRXcc889Y82aNTFz5syoqqqK\nW2+9NVauXBmffPJJrF+/vqixAKBQhYW5Q4cO8ctf/jKGDh0a/fr1iyZNmsR1110XX3zxhV3aAHxv\nFXry1/Dhw2P48OEbbHv11VcLmgYAiueALgAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJ\nCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLM\nAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANA\nIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkI\nMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswA\nkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0Ai\ntQ7zokWL4oADDojZs2fX5zwA8L3mFTMAJCLMAJDIFof53//+dwwYMCA6deoUp512WixevDgiIl5/\n/fU45ZRTonPnztGzZ8+48cYbo7q6uubn/va3v8Wxxx4bHTp0iH79+sWTTz5Zc9uQIUNi7NixMXDg\nwPjFL35RBw8LAMrTFof5gQceiNtuuy1efPHFaNKkSVxxxRWxdOnSOOecc2LQoEHx+uuvx/jx4+Mf\n//hHPPTQQxER8eyzz8ZNN90U119/fUyfPj2GDx8ew4YNi7lz59b83ieeeCJGjRoV48ePr7MHBwDl\npvGW/sDgwYOjTZs2ERFxxhln
xFlnnRUTJ06MvffeOwYNGhQREfvuu28MGTIkHn300TjllFPioYce\nihNOOCEOOeSQiIg48sgjo2fPnvHYY4/FZZddFhER7du3j06dOm32/it2fiIqmuy/pWMXprLV+0WP\n0GBZ2/pjbetPOa3tpOrNf08mk6onFD1CndjiMO+77741X7dt2zZKpVJMmTIl3n333Wjfvn3NbaVS\nKXbZZZeIiFiwYEH885//jHvvvXeD23fccceaf7du3bpW919a1j9KWzp0QSpbvR/VS/creowGydrW\nH2tbf8ptbfu17lj0CLU2qXpCHFV5UtFj1Nq3/RGxxWGurPy/vd+l0leJbNOmTTRt2jTuvPPOTf5M\ns2bN4qKLLoqzzz77mwdpvMWjAECDs8XHmOfNm1fz9YIFC6JRo0Zx4IEHxvvvv7/ByV7Lli2LVatW\nRcRXr6zfe++9DX7PkiVLNvh+AOA7hPn++++Pjz/+OKqqquLuu++OXr16xcCBA6OqqipuueWWWLly\nZSxZsiTOOuusuP322yPiq+PSTz/9dDz77LOxbt26mD59egwcODCmTJlS5w8IAMrZFod58ODBccYZ\nZ8Thhx8e69ati//5n/+J5s2bx2233RYvvfRSdOvWLU4++eTo2rVrnH/++RER0aNHjxg5cmSMGTMm\nOnfuHCNHjozLL788evToUecPCADKWUXp6wPFZaKcTpwotxM9yom1rT/Wtv6U29o6+av+fNvJX678\nBQCJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMA\nJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCI\nMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIM\nAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAk\nIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgw\nA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIo2LuNM+ffrERx99FJWVG/9dMHLk\nyBg8eHABUwFA8QoJc0TEiBEj4rTTTivq7gEgJbuyASARYQaARCpKpVJpW9/ptx1jfuutt6JRo0bf\n+LOltbOjosn+9TkeABSm7I4xl5b1j23+l8R3VNnq/aheul/RYzRI1rb+WNv6U25r2691x6JHqLVJ\n1RPiqMqTih6j1iZVT/jG2+zKBoBEhBkAEhFmAEiksGPMY8aMibFjx260vVevXnHrrbcWMBEAFK+Q\nMD///PNF3C0ApGdXNgAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANA\nIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkI\nMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkUlEq\nlUpFDwEAfMUrZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgkf8HhgybVjMLh/kAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 576x576 with 1 Axes>"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "'''\n",
    "  code by Tae Hwan Jung(Jeff Jung) @graykode, Derek Miller @dmmiller612\n",
    "  Reference : https://github.com/jadore801120/attention-is-all-you-need-pytorch\n",
    "              https://github.com/JayParks/transformer\n",
    "'''\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "dtype = torch.FloatTensor  # legacy alias; unused below\n",
    "# S: symbol marking the start of the decoder input\n",
    "# E: symbol marking the end of the decoder output\n",
    "# P: padding symbol that fills a sequence shorter than the number of time steps\n",
    "# sentences = [encoder input, decoder input, decoder target]\n",
    "sentences = ['ich mochte ein bier P', 'S i want a beer', 'i want a beer E']\n",
    "\n",
    "# Transformer parameters\n",
    "# Padding must be index 0: the attention pad masks compare token ids against 0.\n",
    "src_vocab = {'P' : 0, 'ich' : 1, 'mochte' : 2, 'ein' : 3, 'bier' : 4}\n",
    "src_vocab_size = len(src_vocab)\n",
    "\n",
    "tgt_vocab = {'P' : 0, 'i' : 1, 'want' : 2, 'a' : 3, 'beer' : 4, 'S' : 5, 'E' : 6}\n",
    "# id -> word, built from the stored ids rather than from dict insertion order\n",
    "number_dict = {idx: word for word, idx in tgt_vocab.items()}\n",
    "tgt_vocab_size = len(tgt_vocab)\n",
    "\n",
    "src_len = 5  # encoder input length (tokens in sentences[0])\n",
    "tgt_len = 5  # decoder input / target length (tokens in sentences[1] and sentences[2])\n",
    "\n",
    "d_model = 512  # embedding size\n",
    "d_ff = 2048  # feed-forward hidden dimension\n",
    "d_k = d_v = 64  # per-head dimension of K (= Q) and of V\n",
    "n_layers = 6  # number of encoder layers and of decoder layers\n",
    "n_heads = 8  # number of heads in multi-head attention\n",
    "\n",
    "def make_batch(sentences):\n",
    "    '''Turn the three sentences into [1 x len] LongTensors of token ids.\n",
    "\n",
    "    Returns (enc_inputs, dec_inputs, target_batch): sentences[0] encoded with src_vocab,\n",
    "    sentences[1] and sentences[2] encoded with tgt_vocab.\n",
    "    '''\n",
    "    # Plain tensors: torch.autograd.Variable is deprecated and has been a no-op wrapper since PyTorch 0.4.\n",
    "    input_batch = [[src_vocab[n] for n in sentences[0].split()]]\n",
    "    output_batch = [[tgt_vocab[n] for n in sentences[1].split()]]\n",
    "    target_batch = [[tgt_vocab[n] for n in sentences[2].split()]]\n",
    "    return torch.LongTensor(input_batch), torch.LongTensor(output_batch), torch.LongTensor(target_batch)\n",
    "\n",
    "def get_sinusoid_encoding_table(n_position, d_model):\n",
    "    '''Fixed sinusoidal positional-encoding table as a float32 tensor [n_position x d_model].\n",
    "\n",
    "    Entry (pos, j) is pos / 10000^(2*(j//2)/d_model), passed through sin for even j\n",
    "    and cos for odd j, so each sin/cos column pair shares one frequency.\n",
    "    '''\n",
    "    angles = np.array([\n",
    "        [pos / np.power(10000, 2 * (j // 2) / d_model) for j in range(d_model)]\n",
    "        for pos in range(n_position)\n",
    "    ])\n",
    "    even_cols, odd_cols = angles[:, 0::2], angles[:, 1::2]\n",
    "    angles[:, 0::2] = np.sin(even_cols)  # dim 2i\n",
    "    angles[:, 1::2] = np.cos(odd_cols)  # dim 2i+1\n",
    "    return torch.FloatTensor(angles)\n",
    "\n",
    "def get_attn_pad_mask(seq_q, seq_k):\n",
    "    '''Mask that hides padding keys (token id 0) from every query.\n",
    "\n",
    "    seq_q: [batch_size x len_q], seq_k: [batch_size x len_k] token ids.\n",
    "    Returns [batch_size x len_q x len_k]; an entry is true where the key is padding.\n",
    "    '''\n",
    "    _, len_q = seq_q.size()\n",
    "    batch_size, len_k = seq_k.size()\n",
    "    key_is_pad = seq_k.data.eq(0)  # [batch_size x len_k]\n",
    "    return key_is_pad.unsqueeze(1).expand(batch_size, len_q, len_k)\n",
    "\n",
    "def get_attn_subsequent_mask(seq):\n",
    "    '''Look-ahead mask for decoder self-attention.\n",
    "\n",
    "    seq: [batch_size x seq_len]. Returns a uint8 tensor [batch_size x seq_len x seq_len]\n",
    "    holding 1 strictly above the diagonal, i.e. where a query would see a later position.\n",
    "    '''\n",
    "    batch_size, seq_len = seq.size(0), seq.size(1)\n",
    "    ones = torch.ones(batch_size, seq_len, seq_len)\n",
    "    return torch.triu(ones, diagonal=1).byte()\n",
    "\n",
    "class ScaledDotProductAttention(nn.Module):\n",
    "    '''softmax(Q K^T / sqrt(d_k), with masked scores set to -1e9) applied to V.'''\n",
    "    def __init__(self):\n",
    "        super(ScaledDotProductAttention, self).__init__()\n",
    "\n",
    "    def forward(self, Q, K, V, attn_mask):\n",
    "        '''\n",
    "        Q: [batch_size x n_heads x len_q x d_k], K: [batch_size x n_heads x len_k x d_k],\n",
    "        V: [batch_size x n_heads x len_k x d_v]; attn_mask: [batch_size x n_heads x len_q x len_k],\n",
    "        true where a query must not attend to a key.\n",
    "        Returns (context [batch_size x n_heads x len_q x d_v], attn [batch_size x n_heads x len_q x len_k]).\n",
    "        '''\n",
    "        # Scale by the actual key width rather than the global d_k so the module works for\n",
    "        # any head size (same result in this notebook, where K.size(-1) == d_k).\n",
    "        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(K.size(-1))\n",
    "        # Out-of-place fill: avoids mutating an autograd intermediate in place.\n",
    "        scores = scores.masked_fill(attn_mask, -1e9)\n",
    "        attn = torch.softmax(scores, dim=-1)\n",
    "        context = torch.matmul(attn, V)\n",
    "        return context, attn\n",
    "\n",
    "class MultiHeadAttention(nn.Module):\n",
    "    '''Multi-head attention followed by a residual connection and LayerNorm.'''\n",
    "    def __init__(self):\n",
    "        super(MultiHeadAttention, self).__init__()\n",
    "        self.W_Q = nn.Linear(d_model, d_k * n_heads)\n",
    "        self.W_K = nn.Linear(d_model, d_k * n_heads)\n",
    "        self.W_V = nn.Linear(d_model, d_v * n_heads)\n",
    "        # The output projection and LayerNorm are registered as submodules so they appear in\n",
    "        # model.parameters() and get trained. Constructing them inside forward() would create a\n",
    "        # fresh, randomly initialised projection on every call.\n",
    "        self.linear = nn.Linear(n_heads * d_v, d_model)\n",
    "        self.layer_norm = nn.LayerNorm(d_model)\n",
    "        self.attention = ScaledDotProductAttention()\n",
    "\n",
    "    def forward(self, Q, K, V, attn_mask):\n",
    "        '''\n",
    "        Q: [batch_size x len_q x d_model], K: [batch_size x len_k x d_model], V: [batch_size x len_k x d_model]\n",
    "        attn_mask: [batch_size x len_q x len_k], true where attention must be blocked.\n",
    "        Returns (output [batch_size x len_q x d_model], attn [batch_size x n_heads x len_q x len_k]).\n",
    "        '''\n",
    "        residual, batch_size = Q, Q.size(0)\n",
    "        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)\n",
    "        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # q_s: [batch_size x n_heads x len_q x d_k]\n",
    "        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # k_s: [batch_size x n_heads x len_k x d_k]\n",
    "        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2)  # v_s: [batch_size x n_heads x len_k x d_v]\n",
    "\n",
    "        # Same mask for every head: [batch_size x n_heads x len_q x len_k]\n",
    "        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)\n",
    "\n",
    "        # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q x len_k]\n",
    "        context, attn = self.attention(q_s, k_s, v_s, attn_mask)\n",
    "        # Merge heads back: [batch_size x len_q x n_heads * d_v]\n",
    "        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)\n",
    "        output = self.linear(context)  # [batch_size x len_q x d_model]\n",
    "        return self.layer_norm(output + residual), attn\n",
    "\n",
    "class PoswiseFeedForwardNet(nn.Module):\n",
    "    '''Position-wise feed-forward sublayer (1x1 conv d_model->d_ff, ReLU, 1x1 conv d_ff->d_model)\n",
    "    followed by a residual connection and LayerNorm.'''\n",
    "    def __init__(self):\n",
    "        super(PoswiseFeedForwardNet, self).__init__()\n",
    "        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)\n",
    "        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)\n",
    "        # Registered once (not constructed in forward) so its affine parameters are trainable.\n",
    "        self.layer_norm = nn.LayerNorm(d_model)\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        '''inputs: [batch_size x len_q x d_model] -> output of the same shape.'''\n",
    "        residual = inputs\n",
    "        # Conv1d wants channels first: [batch_size x d_model x len_q]\n",
    "        output = torch.relu(self.conv1(inputs.transpose(1, 2)))\n",
    "        output = self.conv2(output).transpose(1, 2)\n",
    "        return self.layer_norm(output + residual)\n",
    "\n",
    "class EncoderLayer(nn.Module):\n",
    "    '''One encoder block: self-attention sublayer, then the position-wise feed-forward sublayer.'''\n",
    "    def __init__(self):\n",
    "        super(EncoderLayer, self).__init__()\n",
    "        self.enc_self_attn = MultiHeadAttention()\n",
    "        self.pos_ffn = PoswiseFeedForwardNet()\n",
    "\n",
    "    def forward(self, enc_inputs, enc_self_attn_mask):\n",
    "        '''Returns (enc_outputs [batch_size x len_q x d_model], self-attention weights).'''\n",
    "        # Queries, keys and values all come from the same encoder input.\n",
    "        attended, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)\n",
    "        return self.pos_ffn(attended), attn\n",
    "\n",
    "class DecoderLayer(nn.Module):\n",
    "    '''One decoder block: masked self-attention, decoder-encoder attention, then feed-forward.'''\n",
    "    def __init__(self):\n",
    "        super(DecoderLayer, self).__init__()\n",
    "        self.dec_self_attn = MultiHeadAttention()\n",
    "        self.dec_enc_attn = MultiHeadAttention()\n",
    "        self.pos_ffn = PoswiseFeedForwardNet()\n",
    "\n",
    "    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):\n",
    "        '''Returns (dec_outputs, self-attention weights, decoder-encoder attention weights).'''\n",
    "        self_attended, dec_self_attn = self.dec_self_attn(\n",
    "            dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)\n",
    "        # Queries come from the decoder; keys and values from the encoder output.\n",
    "        cross_attended, dec_enc_attn = self.dec_enc_attn(\n",
    "            self_attended, enc_outputs, enc_outputs, dec_enc_attn_mask)\n",
    "        return self.pos_ffn(cross_attended), dec_self_attn, dec_enc_attn\n",
    "\n",
    "class Encoder(nn.Module):\n",
    "    '''Source embedding + fixed sinusoidal positions, followed by n_layers EncoderLayers.'''\n",
    "    def __init__(self):\n",
    "        super(Encoder, self).__init__()\n",
    "        self.src_emb = nn.Embedding(src_vocab_size, d_model)\n",
    "        # src_len + 1 rows so positions 1..src_len are valid indices; padding tokens use position 0.\n",
    "        self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(src_len+1, d_model),freeze=True)\n",
    "        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])\n",
    "\n",
    "    def forward(self, enc_inputs):\n",
    "        '''\n",
    "        enc_inputs: [batch_size x source_len] token ids, source_len <= src_len.\n",
    "        Returns (enc_outputs [batch_size x source_len x d_model], list of per-layer self-attention weights).\n",
    "        '''\n",
    "        # Positions 1..source_len for every sequence, 0 at padding (id 0) tokens. Derived from the\n",
    "        # input instead of hard-coded, so any batch size and any length up to src_len works.\n",
    "        seq_len = enc_inputs.size(1)\n",
    "        positions = torch.arange(1, seq_len + 1, device=enc_inputs.device).unsqueeze(0).expand_as(enc_inputs)\n",
    "        positions = positions.masked_fill(enc_inputs.eq(0), 0)\n",
    "        enc_outputs = self.src_emb(enc_inputs) + self.pos_emb(positions)\n",
    "        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)\n",
    "        enc_self_attns = []\n",
    "        for layer in self.layers:\n",
    "            enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)\n",
    "            enc_self_attns.append(enc_self_attn)\n",
    "        return enc_outputs, enc_self_attns\n",
    "\n",
    "class Decoder(nn.Module):\n",
    "    '''Target embedding + fixed sinusoidal positions, followed by n_layers DecoderLayers.'''\n",
    "    def __init__(self):\n",
    "        super(Decoder, self).__init__()\n",
    "        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)\n",
    "        # tgt_len + 1 rows so positions 1..tgt_len are valid indices; padding tokens use position 0.\n",
    "        self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(tgt_len+1, d_model),freeze=True)\n",
    "        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])\n",
    "\n",
    "    def forward(self, dec_inputs, enc_inputs, enc_outputs):\n",
    "        '''\n",
    "        dec_inputs: [batch_size x target_len] token ids (target_len <= tgt_len);\n",
    "        enc_inputs: [batch_size x source_len] token ids; enc_outputs: encoder output.\n",
    "        Returns (dec_outputs [batch_size x target_len x d_model], per-layer self-attention\n",
    "        weights, per-layer decoder-encoder attention weights).\n",
    "        '''\n",
    "        # Positions 1..target_len for every sequence, 0 at padding (id 0) tokens. The previous\n",
    "        # hard-coded [[5,1,2,3,4]] were the token ids of the decoder input, which gave the start\n",
    "        # symbol the last position and only worked for a single fixed sentence.\n",
    "        seq_len = dec_inputs.size(1)\n",
    "        positions = torch.arange(1, seq_len + 1, device=dec_inputs.device).unsqueeze(0).expand_as(dec_inputs)\n",
    "        positions = positions.masked_fill(dec_inputs.eq(0), 0)\n",
    "        dec_outputs = self.tgt_emb(dec_inputs) + self.pos_emb(positions)\n",
    "        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs)\n",
    "        dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs)\n",
    "        # Block a key if it is padding OR lies in the future.\n",
    "        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)\n",
    "\n",
    "        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs)\n",
    "\n",
    "        dec_self_attns, dec_enc_attns = [], []\n",
    "        for layer in self.layers:\n",
    "            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)\n",
    "            dec_self_attns.append(dec_self_attn)\n",
    "            dec_enc_attns.append(dec_enc_attn)\n",
    "        return dec_outputs, dec_self_attns, dec_enc_attns\n",
    "\n",
    "class Transformer(nn.Module):\n",
    "    '''Encoder-decoder Transformer with a bias-free projection onto the target vocabulary.'''\n",
    "    def __init__(self):\n",
    "        super(Transformer, self).__init__()\n",
    "        self.encoder = Encoder()\n",
    "        self.decoder = Decoder()\n",
    "        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)\n",
    "\n",
    "    def forward(self, enc_inputs, dec_inputs):\n",
    "        '''\n",
    "        enc_inputs: [batch_size x source_len], dec_inputs: [batch_size x target_len] token ids.\n",
    "        Returns (logits [batch_size * target_len x tgt_vocab_size], encoder self-attentions,\n",
    "        decoder self-attentions, decoder-encoder attentions).\n",
    "        '''\n",
    "        enc_outputs, enc_self_attns = self.encoder(enc_inputs)\n",
    "        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)\n",
    "        logits = self.projection(dec_outputs)  # [batch_size x target_len x tgt_vocab_size]\n",
    "        flat_logits = logits.view(-1, logits.size(-1))  # one row per target position\n",
    "        return flat_logits, enc_self_attns, dec_self_attns, dec_enc_attns\n",
    "\n",
    "torch.manual_seed(0)  # fixed seed so weight initialisation is reproducible across runs\n",
    "model = Transformer()\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "def showgraph(attn, x_labels=None, y_labels=None):\n",
    "    '''Plot head 0 of the last layer's attention map for the first batch element.\n",
    "\n",
    "    attn: list of per-layer attention tensors [batch_size x n_heads x len_q x len_k].\n",
    "    x_labels / y_labels: labels for the key (column) and query (row) axes;\n",
    "    they default to sentences[0] / sentences[2].\n",
    "    '''\n",
    "    weights = attn[-1][0, 0].detach().cpu().numpy()  # [len_q x len_k]\n",
    "    if x_labels is None:\n",
    "        x_labels = sentences[0].split()\n",
    "    if y_labels is None:\n",
    "        y_labels = sentences[2].split()\n",
    "    fig, ax = plt.subplots(figsize=(n_heads, n_heads))\n",
    "    ax.matshow(weights, cmap='viridis')\n",
    "    # Fix one tick per cell before labelling so each label lines up with its row/column.\n",
    "    ax.set_xticks(range(len(x_labels)))\n",
    "    ax.set_yticks(range(len(y_labels)))\n",
    "    ax.set_xticklabels(x_labels, fontdict={'fontsize': 14}, rotation=90)\n",
    "    ax.set_yticklabels(y_labels, fontdict={'fontsize': 14})\n",
    "    plt.show()\n",
    "\n",
    "enc_inputs, dec_inputs, target_batch = make_batch(sentences)  # the single training pair never changes\n",
    "for epoch in range(20):\n",
    "    optimizer.zero_grad()\n",
    "    outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)\n",
    "    loss = criterion(outputs, target_batch.contiguous().view(-1))\n",
    "    print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "# Test. Note: this is teacher-forced (the gold decoder input sentences[1] is fed in),\n",
    "# not autoregressive greedy decoding.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    # Keep the attention maps from this pass so the plots reflect the fully trained model.\n",
    "    predict, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)\n",
    "predict = predict.argmax(dim=1)\n",
    "print(sentences[0], '->', [number_dict[n.item()] for n in predict])\n",
    "\n",
    "print('first head of last state enc_self_attns')\n",
    "showgraph(enc_self_attns, x_labels=sentences[0].split(), y_labels=sentences[0].split())\n",
    "\n",
    "print('first head of last state dec_self_attns')\n",
    "showgraph(dec_self_attns, x_labels=sentences[1].split())\n",
    "\n",
    "print('first head of last state dec_enc_attns')\n",
    "showgraph(dec_enc_attns)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "Transformer-Torch.ipynb",
   "version": "0.3.2",
   "provenance": [],
   "collapsed_sections": []
  },
  "kernelspec": {
   "name": "python3",
   "language": "python",
   "display_name": "Python 3"
  },
  "accelerator": "GPU",
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}